{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pytorch_tabular.utils import load_covertype_dataset\n",
    "from rich import print\n",
    "# %load_ext autoreload\n",
    "# %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names, target_col = load_covertype_dataset()\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, val = train_test_split(train, random_state=42)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import (\n",
    "    CategoryEmbeddingModelConfig\n",
    ")\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Test-Time Augmentation\n",
    "\n",
    "Test-time augmentation (TTA) is a popular technique in computer vision. TTA aims to boost model accuracy by applying data augmentation at the inference stage. The idea behind TTA is simple: for each test image, we create multiple versions that are slightly different from the original (e.g., cropped or flipped). Next, we predict labels for the test images and their augmented copies, and average the model predictions over the multiple versions of each image. This usually helps to improve accuracy irrespective of the underlying model.   \n",
    "\n",
    "For more details, refer to this link: [Test-Time Augmentation for Tabular Data](https://kozodoi.me/python/structured%20data/test-time%20augmentation/2021/09/08/tta-tabular.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the data, trainer, optimizer, and model configs, then assemble the TabularModel.\n",
    "data_config = DataConfig(\n",
    "    target=[\n",
    "        target_col\n",
    "    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    early_stopping=\"valid_loss\",  # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode=\"min\",  # Set the mode as min because for val_loss, lower is better\n",
    "    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\",  # Save best checkpoint monitoring val_loss\n",
    "    load_best=True,  # After training, load the best checkpoint\n",
    "    #     progress_bar=\"none\", # Turning off Progress bar\n",
    "    #     trainer_kwargs=dict(\n",
    "    #         enable_model_summary=False # Turning off model summary\n",
    "    #     )\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\", dropout=0.1, initialization=\"kaiming\"  # No additional layer in head, just a mapping layer to output_dim\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"1024-512-512\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\",  # Activation between layers\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory saved_models exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  823 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  3.6 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  823 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │    896 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  3.6 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 827 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 827 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 3                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 827 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 827 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7c740de999644658803a999d1b846413",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n",
       "</pre>\n"
      ],
      "text/plain": [
       "/home/manujosephv/miniconda3/envs/lightning_upgrade/lib/python3.11/site-packages/pytorch_lightning/trainer/connecto\n",
       "rs/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider \n",
       "increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7f56e2806610>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "datamodule = tabular_model.prepare_dataloader(train=train, validation=val, seed=42)\n",
    "model = tabular_model.prepare_model(datamodule)\n",
    "tabular_model.train(model, datamodule)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_df = tabular_model.predict(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "tta_pred_df = tabular_model.predict(test, test_time_augmentation=True, num_tta=5, alpha_tta=0.005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">Original Accuracy: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9450889138262205</span> | TTA Accuracy: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9450062993535417</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "Original Accuracy: \u001b[1;36m0.9450889138262205\u001b[0m | TTA Accuracy: \u001b[1;36m0.9450062993535417\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Calculating metrics\n",
    "orig_acc = accuracy_score(test[target_col].values, pred_df[\"prediction\"].values)\n",
    "tta_acc = accuracy_score(test[target_col].values, tta_pred_df[\"prediction\"].values)\n",
    "print(f\"Original Accuracy: {orig_acc} | TTA Accuracy: {tta_acc}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
